{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 参数管理\n",
    "\n",
    "一旦我们选择了架构并设置了超参数，我们就进入了训练阶段。此时，我们的目标是找到使损失函数最小化的参数值。经过训练后，我们将需要使用这些参数来做出未来的预测。此外，有时我们希望提取参数，以便在其他环境中复用它们，将模型保存到磁盘，以便它可以在其他软件中执行，或者为了获得科学的理解而进行检查。\n",
    "\n",
    "大多数情况下，我们可以忽略声明和操作参数的具体细节，而只依靠深度学习框架来完成繁重的工作。然而，当我们离开具有标准层的层叠架构时，我们有时会陷入声明和操作参数的麻烦中。在本节中，我们将介绍以下内容：\n",
    "\n",
    "* 访问参数，用于调试、诊断和可视化。\n",
    "* 参数初始化。\n",
    "* 在不同模型组件间共享参数。\n",
    "\n",
    "我们首先关注具有单隐藏层的多层感知机。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public SequentialBlock getNet() {\n",
    "    SequentialBlock net = new SequentialBlock();\n",
    "    net.add(Linear.builder().setUnits(8).build());\n",
    "    net.add(Activation.reluBlock());\n",
    "    net.add(Linear.builder().setUnits(1).build());\n",
    "    return net;\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "SequentialBlock net = getNet();\n",
    "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "\n",
    "NDArray x = manager.randomUniform(0, 1, new Shape(2, 4));\n",
    "net.initialize(manager, DataType.FLOAT32, x.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"lin-reg\");\n",
    "model.setBlock(net);\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "\n",
    "predictor.predict(new NDList(x)).singletonOrThrow(); // forward computation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参数访问\n",
    "\n",
    "我们来看一下如何从已有模型中访问参数。对于存在嵌套块的复杂模型，我们需要递归整个树来提取每个子块的参数。`DJL` 提供了 `Block.getParameters()` 函数来简化参数的访问。这是的我们可以通过索引或参数的名称来访问模型的任意参数。这就像模型是一个表一样。每层的参数都在其属性中。如下所示，我们可以检查第二个全连接层的参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ParameterList params = net.getParameters();\n",
    "// Print out all the keys (unique!)\n",
    "for (var pair : params) {\n",
    "    System.out.println(pair.getKey());\n",
    "}\n",
    "\n",
    "// Use the unique key to access the Parameter\n",
    "NDArray dense0Weight = params.get(\"01Linear_weight\").getArray();\n",
    "NDArray dense0Bias = params.get(\"01Linear_bias\").getArray();\n",
    "\n",
    "// Use indexing to access the Parameter\n",
    "NDArray dense1Weight = params.valueAt(2).getArray();\n",
    "NDArray dense1Bias = params.valueAt(3).getArray();\n",
    "\n",
    "System.out.println(dense0Weight);\n",
    "System.out.println(dense0Bias);\n",
    "\n",
    "System.out.println(dense1Weight);\n",
    "System.out.println(dense1Bias);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输出的结果告诉我们一些重要的事情。首先，这个全连接层包含两个参数，分别是该层的权重和偏置。两者都存储为单精度浮点数（float32）。注意，参数名称允许我们唯一地标识每个参数，即使在包含数百个层的网络中也是如此。\n",
    "\n",
    "### 目标参数\n",
    "\n",
    "注意，每个参数都表示为参数 `Parameter` 类的一个实例。参数是复合的对象，包含值、梯度和额外信息。除了值之外，我们还可以访问每个参数的梯度。由于我们还没有调用这个网络的反向传播，所以参数的梯度处于初始状态。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dense0Weight.getGradient();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 从嵌套块收集参数\n",
    "\n",
    "让我们看看，如果我们将多个块相互嵌套，参数命名约定是如何工作的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public SequentialBlock block1() {\n",
    "    SequentialBlock net = new SequentialBlock();\n",
    "    net.add(Linear.builder().setUnits(8).build());\n",
    "    net.add(Activation.reluBlock());\n",
    "    net.add(Linear.builder().setUnits(1).build());\n",
    "    net.add(Activation.reluBlock());\n",
    "    return net;\n",
    "}\n",
    "\n",
    "public SequentialBlock block2() {\n",
    "    SequentialBlock net = new SequentialBlock();\n",
    "    for (int i = 0; i < 4; i++) {\n",
    "        net.add(block1());\n",
    "    }\n",
    "    return net;\n",
    "}\n",
    "\n",
    "SequentialBlock rgnet = new SequentialBlock();\n",
    "rgnet.add(block2());\n",
    "rgnet.add(Linear.builder().setUnits(10).build());\n",
    "rgnet.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "rgnet.initialize(manager, DataType.FLOAT32, x.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"rgnet\");\n",
    "model.setBlock(rgnet);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "\n",
    "predictor.predict(new NDList(x)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们已经设计了网络，让我们看看它是如何组织的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rgnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也可以用列表的方式打印出所有的参数名称："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for (var pair : rgnet.getParameters()) {\n",
    "    System.out.println(pair.getKey());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为层是分层嵌套的，所以我们也可以像通过嵌套列表索引一样访问它们。例如，我们下面访问第一个主要的块，其中第二个子块的第一层的偏置项。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Block majorBlock1 = rgnet.getChildren().get(0).getValue();\n",
    "Block subBlock2 = majorBlock1.getChildren().valueAt(1);\n",
    "Block linearLayer1 = subBlock2.getChildren().valueAt(0);\n",
    "NDArray bias = linearLayer1.getParameters().valueAt(1).getArray();\n",
    "bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参数初始化\n",
    "\n",
    "我们知道了如何访问参数，现在让我们看看如何正确地初始化参数。我们在 :numref:`sec_numerical_stability` 中讨论了良好初始化的必要性。`DJL` 提供多种初始化类(`Initializer`), 也允许创建自定义初始化方法。\n",
    "\n",
    "对每个类型的参数，`DJL` 提供不同的默认初始化的方法。例如：权重参数默认使用 `XavierInitializer` 初始化，而将偏置参数设置为0。\n",
    "\n",
    "\n",
    "### 内置初始化类\n",
    "\n",
    "让我们首先调用内置的初始化器，且将偏置参数设置为0。\n",
    "\n",
    "我们先来看一下如何将所有参数初始化为给定的常数（比如42），我们可以使用内置的 `ConstantInitializer` 类。`ConstantInitializer` 初始化器会把所有权重参数（类型为 `WEIGHT`）初始值设为 42。\n",
    "\n",
    "注意：如果参数已经被初始化，`setInitializer()` 函数并不会对参数有任何影响， 下面的代码不会把权重重置为 42。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.setInitializer(new ConstantInitializer(42), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, new Shape(2, 4));\n",
    "Block linearLayer = net.getChildren().get(0).getValue();\n",
    "NDArray weight = linearLayer.getParameters().get(0).getValue().getArray();\n",
    "weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们创建一个新的块，使用同样的代码，这次所有的权重都被初始化为 42:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = getNet();\n",
    "net.setInitializer(new ConstantInitializer(42), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, new Shape(2, 4));\n",
    "Block linearLayer = net.getChildren().get(0).getValue();\n",
    "NDArray weight = linearLayer.getParameters().get(0).getValue().getArray();\n",
    "weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的代码将所有权重参数初始化为标准差为0.01的高斯随机变量 (`NormalInitializer`)："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = getNet();\n",
    "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, new Shape(2, 4));\n",
    "Block linearLayer = net.getChildren().valueAt(0);\n",
    "NDArray weight = linearLayer.getParameters().valueAt(0).getArray();\n",
    "weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也可以直接访问参数，调用参数类的 `Parameter.setInitializer()` 函数。\n",
    "\n",
    "\n",
    "下面我们使用 `XavierInitializer` 初始化方法初始化第一层，然后第二层初始化为常量值 1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = getNet();\n",
    "ParameterList params = net.getParameters();\n",
    "\n",
    "params.get(\"01Linear_weight\").setInitializer(new NormalInitializer());\n",
    "params.get(\"03Linear_weight\").setInitializer(Initializer.ONES);\n",
    "\n",
    "net.initialize(manager, DataType.FLOAT32, new Shape(2, 4));\n",
    "\n",
    "System.out.println(params.valueAt(0).getArray());\n",
    "System.out.println(params.valueAt(2).getArray());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 自定义初始化\n",
    "\n",
    "有时，`DJL` 没有提供我们需要的初始化方法。在下面的例子中，我们使用以下的分布为任意权重参数$w$定义初始化方法：\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "    w \\sim \\begin{cases}\n",
    "        U(5, 10) & \\text{ with probability } \\frac{1}{4} \\\\\n",
    "            0    & \\text{ with probability } \\frac{1}{2} \\\\\n",
    "        U(-10, -5) & \\text{ with probability } \\frac{1}{4}\n",
    "    \\end{cases}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "在这里，我们定义了`Initializer`类的子类。我们只需要实现 `initialize` 函数，该函数接受 `NDManager`, `Shape` 和 `DataType` 参数:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyInit implements Initializer {\n",
    "\n",
    "    public MyInit() {}\n",
    "\n",
    "    @Override\n",
    "    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {\n",
    "        System.out.printf(\"Init %s\\n\", shape.toString());\n",
    "        // Here we generate data points \n",
    "        // from a uniform distribution [-10, 10]\n",
    "        NDArray data = manager.randomUniform(-10, 10, shape, dataType);\n",
    "        // We keep the data points whose absolute value is >= 5\n",
    "        // and set the others to 0.\n",
    "        // This generates the distribution `w` shown above.\n",
    "        NDArray absGte5 = data.abs().gte(5); // returns boolean NDArray where \n",
    "                                             // true indicates abs >= 5 and\n",
    "                                             // false otherwise\n",
    "        return data.mul(absGte5); // keeps true indices and sets false indices to 0.\n",
    "                                  // special operation when multiplying a numerical\n",
    "                                  // NDArray with a boolean NDArray\n",
    "    }\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = getNet();\n",
    "net.setInitializer(new MyInit(), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, x.getShape());\n",
    "Block linearLayer = net.getChildren().valueAt(0);\n",
    "NDArray weight = linearLayer.getParameters().valueAt(0).getArray();\n",
    "weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意，我们始终可以直接设置参数。\n",
    "\n",
    "高级用户请注意：不能在`GarbageCollector`范围内调整参数，以避免误导自动微分机制。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray weightLayer = net.getChildren().valueAt(0)\n",
    "    .getParameters().valueAt(0).getArray();\n",
    "weightLayer.addi(7);\n",
    "weightLayer.divi(9);\n",
    "weightLayer.set(new NDIndex(0, 0), 2020); // set the (0, 0) index to 2020\n",
    "weightLayer;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参数绑定\n",
    "\n",
    "有时我们希望在多个层间共享参数。让我们看看如何优雅地做这件事。在下面，我们定义一个稠密层，然后使用它的参数来设置另一个层的参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "\n",
    "// 我们需要给共享层一个名称，以便可以引用它的参数。\n",
    "Block shared = Linear.builder().setUnits(8).build();\n",
    "SequentialBlock sharedRelu = new SequentialBlock();\n",
    "sharedRelu.add(shared);\n",
    "sharedRelu.add(Activation.reluBlock());\n",
    "\n",
    "net.add(Linear.builder().setUnits(8).build());\n",
    "net.add(Activation.reluBlock());\n",
    "net.add(sharedRelu);\n",
    "net.add(sharedRelu);\n",
    "net.add(Linear.builder().setUnits(10).build());\n",
    "\n",
    "NDArray x = manager.randomUniform(-10f, 10f, new Shape(2, 20), DataType.FLOAT32);\n",
    "\n",
    "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, x.getShape());\n",
    "\n",
    "model.setBlock(net);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "System.out.println(predictor.predict(new NDList(x)).singletonOrThrow());\n",
    "\n",
    "// Check that the parameters are the same\n",
    "NDArray shared1 = net.getChildren().valueAt(2)\n",
    "    .getParameters().valueAt(0).getArray();\n",
    "NDArray shared2 = net.getChildren().valueAt(3)\n",
    "    .getParameters().valueAt(0).getArray();\n",
    "shared1.eq(shared2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个例子表明第二层和第三层的参数是绑定的。它们不仅值相等，而且由相同的张量表示。因此，如果我们改变其中一个参数，另一个参数也会改变。你可能会想，当参数绑定时，梯度会发生什么情况？答案是由于模型参数包含梯度，因此在反向传播期间第二个隐藏层和第三个隐藏层的梯度会加在一起。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 我们有几种方法可以访问、初始化和绑定模型参数。\n",
    "* 我们可以使用自定义初始化方法。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 使用 :numref:`sec_model_construction` 中定义的`FancyMLP`模型，访问各个层的参数。\n",
    "1. 查看初始化模块文档以了解不同的初始化方法。\n",
    "1. 构建包含共享参数层的多层感知机并对其进行训练。在训练过程中，观察模型各层的参数和梯度。\n",
    "1. 为什么共享参数是个好主意？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
